#include<iostream>
#include<algorithm>
#include<cstdio>
// Modulus applied to the printed result.
constexpr int mod=998244353;
// Short alias for 64-bit signed arithmetic.
using ll=long long;
// Integer constant 1000.
constexpr int p=1000;
using namespace std;
// Input values (n, m) and running sums; zero-initialized since they are globals.
ll sum1,sum2,n,m;
int main(){
	freopen("bpmp.in","r",stdin);
	freopen("bpmp.out","w",stdout);
	cin>>n>>m;
	for(ll i=1;i<=n/p;i++)
		sum1=sum1+p*(m-1),sum1=sum1%mod;
	for(int i=1;i<=n%p;i++)sum1=sum1+m-1,sum1=sum1%mod;
	for(ll i=1;i<=m/p;i++)
		sum2=sum2+p*(n-1),sum2=sum2%mod;
	for(int i=1;i<=m%p;i++)sum2=sum2+n-1,sum2=sum2%mod;
	cout<<min(sum1+n-1,sum2+m-1)%mod;
	return 0;
}
